{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matchzoo as mz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_pack = mz.datasets.wiki_qa.load_data('train', task='ranking')\n",
    "valid_pack = mz.datasets.wiki_qa.load_data('dev', task='ranking', filter=True)\n",
    "predict_pack = mz.datasets.wiki_qa.load_data('test', task='ranking', filter=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 8054.86it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 18841/18841 [00:03<00:00, 4800.39it/s]\n",
      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 778340.21it/s]\n",
      "Building FrequencyFilterUnit from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 130441.55it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 70890.36it/s] \n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 550507.27it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 684050.05it/s]\n",
      "Building VocabularyUnit from a datapack.: 100%|██████████| 404415/404415 [00:00<00:00, 2774466.49it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 8773.78it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 18841/18841 [00:03<00:00, 4809.19it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 120711.33it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 160078.13it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 118142.61it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 525715.22it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 687788.90it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 115610.83it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 91451.28it/s]\n"
     ]
    }
   ],
   "source": [
    "preprocessor = mz.preprocessors.BasicPreprocessor(fixed_length_left=10, fixed_length_right=40, remove_stop_words=False)\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 122/122 [00:00<00:00, 8070.42it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 1115/1115 [00:00<00:00, 4889.07it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 117799.72it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 102095.99it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 120594.35it/s]\n",
      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 216182.97it/s]\n",
      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 489804.04it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 70058.20it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 76144.60it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 237/237 [00:00<00:00, 9151.80it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2300/2300 [00:00<00:00, 4830.06it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 130539.91it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 167011.10it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 128931.32it/s]\n",
      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 241626.17it/s]\n",
      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 606875.89it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 98187.48it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 91073.78it/s]\n"
     ]
    }
   ],
   "source": [
    "valid_pack_processed = preprocessor.transform(valid_pack)\n",
    "predict_pack_processed = preprocessor.transform(predict_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "ranking_task = mz.tasks.Ranking(loss=mz.losses.RankHingeLoss())\n",
    "ranking_task.metrics = [\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=3),\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=5),\n",
    "    mz.metrics.MeanAveragePrecision()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter \"name\" set to MVLSTM.\n",
      "Parameter \"embedding_trainable\" set to True.\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "text_left (InputLayer)          (None, 10)           0                                            \n",
      "__________________________________________________________________________________________________\n",
      "text_right (InputLayer)         (None, 40)           0                                            \n",
      "__________________________________________________________________________________________________\n",
      "embedding (Embedding)           multiple             1667400     text_left[0][0]                  \n",
      "                                                                 text_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "bidirectional_1 (Bidirectional) (None, 10, 100)      60400       embedding[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "bidirectional_2 (Bidirectional) (None, 40, 100)      60400       embedding[1][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dot_1 (Dot)                     (None, 10, 40)       0           bidirectional_1[0][0]            \n",
      "                                                                 bidirectional_2[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "reshape_1 (Reshape)             (None, 400)          0           dot_1[0][0]                      \n",
      "__________________________________________________________________________________________________\n",
      "lambda_1 (Lambda)               (None, 20)           0           reshape_1[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 10)           210         lambda_1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "dense_2 (Dense)                 (None, 10)           110         dense_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_3 (Dense)                 (None, 5)            55          dense_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dropout_1 (Dropout)             (None, 5)            0           dense_3[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_4 (Dense)                 (None, 1)            6           dropout_1[0][0]                  \n",
      "==================================================================================================\n",
      "Total params: 1,788,581\n",
      "Trainable params: 1,788,581\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = mz.models.MVLSTM()\n",
    "model.params['input_shapes'] = preprocessor.context['input_shapes']\n",
    "model.params['task'] = ranking_task\n",
    "model.params['embedding_input_dim'] = preprocessor.context['vocab_size']\n",
    "model.params['embedding_output_dim'] = 100\n",
    "model.params['lstm_units'] = 50\n",
    "model.params['top_k'] = 20\n",
    "model.params['mlp_num_layers'] = 2\n",
    "model.params['mlp_num_units'] = 10\n",
    "model.params['mlp_num_fan_out'] = 5\n",
    "model.params['mlp_activation_func'] = 'relu'\n",
    "model.params['dropout_rate'] = 0.5\n",
    "model.params['optimizer'] = 'adadelta'\n",
    "model.guess_and_fill_missing_params()\n",
    "model.build()\n",
    "model.compile()\n",
    "model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "glove_embedding = mz.datasets.embeddings.load_glove_embedding(dimension=100)\n",
    "embedding_matrix = glove_embedding.build_matrix(preprocessor.context['vocab_unit'].state['term_index'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.load_embedding_matrix(embedding_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_x, pred_y = predict_pack_processed[:].unpack()\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "102"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_generator = mz.PairDataGenerator(train_pack_processed, num_dup=2, num_neg=1, batch_size=20)\n",
    "len(train_generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/30\n",
      "102/102 [==============================] - 12s 118ms/step - loss: 1.0106\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.4395442884097643 - normalized_discounted_cumulative_gain@5(0): 0.4901664416079868 - mean_average_precision(0): 0.46504195849436275\n",
      "Epoch 2/30\n",
      "102/102 [==============================] - 9s 85ms/step - loss: 0.9472\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.40495809628544793 - normalized_discounted_cumulative_gain@5(0): 0.47519944198098163 - mean_average_precision(0): 0.44288883214230373\n",
      "Epoch 3/30\n",
      "102/102 [==============================] - 9s 88ms/step - loss: 0.9258\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.3679848819682749 - normalized_discounted_cumulative_gain@5(0): 0.45307045673492513 - mean_average_precision(0): 0.4184850223054166\n",
      "Epoch 4/30\n",
      "102/102 [==============================] - 9s 85ms/step - loss: 0.8885\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.48446984913747804 - normalized_discounted_cumulative_gain@5(0): 0.5426550162890549 - mean_average_precision(0): 0.5061456854887297\n",
      "Epoch 5/30\n",
      "102/102 [==============================] - 9s 89ms/step - loss: 0.8577\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.4311335856085815 - normalized_discounted_cumulative_gain@5(0): 0.5040511533341872 - mean_average_precision(0): 0.453401456228557\n",
      "Epoch 6/30\n",
      "102/102 [==============================] - 9s 89ms/step - loss: 0.8481\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.4964927588142032 - normalized_discounted_cumulative_gain@5(0): 0.5590616576327969 - mean_average_precision(0): 0.5189103169060975\n",
      "Epoch 7/30\n",
      "102/102 [==============================] - 9s 90ms/step - loss: 0.8141\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5155367064254569 - normalized_discounted_cumulative_gain@5(0): 0.5945794798476091 - mean_average_precision(0): 0.542255642983181\n",
      "Epoch 8/30\n",
      "102/102 [==============================] - 9s 89ms/step - loss: 0.7849\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5350301005091076 - normalized_discounted_cumulative_gain@5(0): 0.6036622840217906 - mean_average_precision(0): 0.5492099553336506\n",
      "Epoch 9/30\n",
      "102/102 [==============================] - 9s 89ms/step - loss: 0.7870\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5451886915054126 - normalized_discounted_cumulative_gain@5(0): 0.6102707067225249 - mean_average_precision(0): 0.5632031659141037\n",
      "Epoch 10/30\n",
      "102/102 [==============================] - 9s 88ms/step - loss: 0.7582\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5490193863654211 - normalized_discounted_cumulative_gain@5(0): 0.6121719184285449 - mean_average_precision(0): 0.5636763882358322\n",
      "Epoch 11/30\n",
      "102/102 [==============================] - 9s 91ms/step - loss: 0.7627\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5482839385805172 - normalized_discounted_cumulative_gain@5(0): 0.6123828329665424 - mean_average_precision(0): 0.5610809299725662\n",
      "Epoch 12/30\n",
      "102/102 [==============================] - 9s 92ms/step - loss: 0.7090\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5579092543626675 - normalized_discounted_cumulative_gain@5(0): 0.626377814854734 - mean_average_precision(0): 0.574857140937906\n",
      "Epoch 13/30\n",
      "102/102 [==============================] - 9s 91ms/step - loss: 0.7591\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5720464151378393 - normalized_discounted_cumulative_gain@5(0): 0.6493450327088173 - mean_average_precision(0): 0.5879686359116739\n",
      "Epoch 14/30\n",
      "102/102 [==============================] - 9s 91ms/step - loss: 0.7299\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5555904445024372 - normalized_discounted_cumulative_gain@5(0): 0.6240198935000864 - mean_average_precision(0): 0.5729636414921225\n",
      "Epoch 15/30\n",
      "102/102 [==============================] - 10s 93ms/step - loss: 0.6922\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5783744243527581 - normalized_discounted_cumulative_gain@5(0): 0.6488263146716875 - mean_average_precision(0): 0.6027039316754507\n",
      "Epoch 16/30\n",
      "102/102 [==============================] - 9s 93ms/step - loss: 0.7043\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.576958171893885 - normalized_discounted_cumulative_gain@5(0): 0.6380426868498167 - mean_average_precision(0): 0.5901291471069952\n",
      "Epoch 17/30\n",
      "102/102 [==============================] - 10s 94ms/step - loss: 0.6872\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5922195585708011 - normalized_discounted_cumulative_gain@5(0): 0.6543608312163903 - mean_average_precision(0): 0.6131936829327296\n",
      "Epoch 18/30\n",
      "102/102 [==============================] - 10s 95ms/step - loss: 0.6789\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5799238427248672 - normalized_discounted_cumulative_gain@5(0): 0.6527265091630354 - mean_average_precision(0): 0.5988390375632215\n",
      "Epoch 19/30\n",
      "102/102 [==============================] - 10s 94ms/step - loss: 0.6718\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5760817675285786 - normalized_discounted_cumulative_gain@5(0): 0.6485728859276667 - mean_average_precision(0): 0.6007116289366847\n",
      "Epoch 20/30\n",
      "102/102 [==============================] - 10s 95ms/step - loss: 0.6341\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6047980949810488 - normalized_discounted_cumulative_gain@5(0): 0.6727226685761657 - mean_average_precision(0): 0.6215487673391377\n",
      "Epoch 21/30\n",
      "102/102 [==============================] - 10s 94ms/step - loss: 0.6507\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5998274098890153 - normalized_discounted_cumulative_gain@5(0): 0.6661844462066812 - mean_average_precision(0): 0.61276298204596\n",
      "Epoch 22/30\n",
      "102/102 [==============================] - 10s 93ms/step - loss: 0.6344\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5981574237277043 - normalized_discounted_cumulative_gain@5(0): 0.6649471218549592 - mean_average_precision(0): 0.6109707993785276\n",
      "Epoch 23/30\n",
      "102/102 [==============================] - 10s 95ms/step - loss: 0.6244\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6019711076180565 - normalized_discounted_cumulative_gain@5(0): 0.6640229378170236 - mean_average_precision(0): 0.61662933662434\n",
      "Epoch 24/30\n",
      "102/102 [==============================] - 10s 93ms/step - loss: 0.6020\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6046924623547137 - normalized_discounted_cumulative_gain@5(0): 0.6644429708867042 - mean_average_precision(0): 0.6149645742209033\n",
      "Epoch 25/30\n",
      "102/102 [==============================] - 10s 93ms/step - loss: 0.6220\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5992623655473802 - normalized_discounted_cumulative_gain@5(0): 0.6567146657797255 - mean_average_precision(0): 0.6160017273755174\n",
      "Epoch 26/30\n",
      "102/102 [==============================] - 10s 94ms/step - loss: 0.5983\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6130299335188328 - normalized_discounted_cumulative_gain@5(0): 0.6676286907935667 - mean_average_precision(0): 0.6221540139075508\n",
      "Epoch 27/30\n",
      "102/102 [==============================] - 10s 94ms/step - loss: 0.5943\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.60779665419059 - normalized_discounted_cumulative_gain@5(0): 0.6654200729693597 - mean_average_precision(0): 0.6145514752132399\n",
      "Epoch 28/30\n",
      "102/102 [==============================] - 10s 96ms/step - loss: 0.5580\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6258837610582709 - normalized_discounted_cumulative_gain@5(0): 0.6814206912650322 - mean_average_precision(0): 0.6323022576000982\n",
      "Epoch 29/30\n",
      "102/102 [==============================] - 10s 93ms/step - loss: 0.5666\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6210208831683093 - normalized_discounted_cumulative_gain@5(0): 0.6742247530256814 - mean_average_precision(0): 0.635696095980906\n",
      "Epoch 30/30\n",
      "102/102 [==============================] - 9s 89ms/step - loss: 0.5449\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.610642911325976 - normalized_discounted_cumulative_gain@5(0): 0.6684165773322793 - mean_average_precision(0): 0.6292750123074025\n"
     ]
    }
   ],
   "source": [
    "history = model.fit_generator(train_generator, epochs=30, callbacks=[evaluate], workers=30, use_multiprocessing=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
